# encoding: utf-8

"""
@Time: 2021-07-20 11:23 
@Author: Libing Wang
@File: main.py 
@description: 
"""

import os
import shutil
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
from utils import *
import cfgs_classifier as cfgs

os.environ["CUDA_VISIBLE_DEVICES"] = '0'


def dis_cfgs(models):
    """Print every configuration group, then each model's architecture."""
    sections = (
        ("global_cfgs", cfgs.global_cfgs),
        ("dataset_cfgs", cfgs.dataset_cfgs),
        ("net_cfgs", cfgs.net_cfgs),
        ("optimizer_cfgs", cfgs.optimizer_cfgs),
        ("saving_cfgs", cfgs.saving_cfgs),
    )
    for title, section in sections:
        print(title)
        cfgs.show_cfgs(section)
    for model in models:
        print(model)


def flatten_label(target):
    """Flatten a batch of zero-padded label sequences.

    Each row of *target* is a label sequence terminated by its first 0
    (the EOS token).  The tokens up to and including that 0 are kept and
    concatenated across the batch.

    Args:
        target: 2-D integer tensor of shape [batch, max_len]; every row
            must contain at least one 0, otherwise ValueError is raised
            by list.index.

    Returns:
        [label_flatten, label_length]: a 1-D LongTensor with all kept
        tokens, and a 1-D LongTensor with the kept length of each row.
    """
    label_flatten = []
    label_length = []
    for row in target:
        cur_label = row.tolist()
        # Locate the EOS marker once per row (the original scanned twice).
        end = cur_label.index(0) + 1
        label_flatten += cur_label[:end]
        label_length.append(end)
    label_flatten = torch.LongTensor(label_flatten)
    label_length = torch.LongTensor(label_length)
    return [label_flatten, label_length]


def train_or_eval(models, state='train'):
    """Put every model into train or eval mode.

    Any *state* other than "Train"/"train" selects eval mode.
    """
    training = state in ("Train", "train")
    for model in models:
        # Module.train(False) is equivalent to Module.eval().
        model.train(training)


def zero_grad(models):
    """Clear the accumulated gradients of every model in *models*."""
    for net in models:
        net.zero_grad()


def update_param(optimizers, frozen):
    """Step each optimizer whose position is not listed in *frozen*.

    Args:
        optimizers: sequence of torch optimizers.
        frozen: collection of indices to skip (kept frozen this step).
    """
    for idx, optimizer in enumerate(optimizers):
        if idx in frozen:
            continue
        optimizer.step()


def load_dataset():
    """Build the train/test datasets from cfgs and wrap them in DataLoaders.

    Returns:
        [train_loader, test_loader]
    """
    dcfg = cfgs.dataset_cfgs
    train_set = dcfg["dataset_train"](**dcfg["dataset_train_args"])
    test_set = dcfg["dataset_test"](**dcfg["dataset_test_args"])
    return [DataLoader(train_set, **dcfg["dataloader_train"]),
            DataLoader(test_set, **dcfg["dataloader_test"])]


def create_network():
    """Instantiate the three sub-networks, load optional checkpoints, and
    move them to the global *devices*.

    Returns:
        [model_fe, model_cam, model_dtd]
    """
    net = cfgs.net_cfgs
    model_fe = net["FE"](**net["FE_args"])
    # CAM needs the feature-map shapes produced by the extractor.
    net["CAM_args"]["scales"] = model_fe.need_shapes()
    model_cam = net["CAM"](**net["CAM_args"])
    model_dtd = net["DTD"](**net["DTD_args"])

    for model, ckpt_key in ((model_fe, "init_state_dict_fe"),
                            (model_cam, "init_state_dict_cam"),
                            (model_dtd, "init_state_dict_dtd")):
        ckpt = net[ckpt_key]
        if ckpt is not None:
            # Load on CPU first; the .to(devices) below handles placement.
            model.load_state_dict(torch.load(ckpt, map_location='cpu'))
        model.to(devices)

    return [model_fe, model_cam, model_dtd]


def generator_optimizer(models):
    """Create one optimizer and one LR scheduler per model.

    The classes and their kwargs are looked up in cfgs.optimizer_cfgs by
    positional keys ("optimizer_0", "optimizer_0_args", ...).

    Returns:
        (optimizers, schedulers) — two parallel lists.
    """
    optimizers = []
    schedulers = []
    ocfg = cfgs.optimizer_cfgs
    for idx, model in enumerate(models):
        optimizer = ocfg["optimizer_%d" % idx](
            model.parameters(), **ocfg["optimizer_%d_args" % idx])
        scheduler = ocfg["optimizer_%d_scheduler" % idx](
            optimizer, **ocfg["optimizer_%d_scheduler_args" % idx])
        optimizers.append(optimizer)
        schedulers.append(scheduler)
    return optimizers, schedulers


def model_eval(test_loader, models, tools):
    """Run one full pass over *test_loader* and report accuracy metrics.

    Args:
        test_loader: DataLoader yielding {"image": ..., "label": ...}.
        models: [feature_extractor, cam, dtd] as built by create_network().
        tools: [character transcoder, flatten_label, accuracy counter].

    Returns:
        (acc, ar, cer, wer, temp) where temp is one
        "label: ...  ----- pred: ..." string per sample.
    """
    train_or_eval(models, 'eval')
    temp = []
    # Inference only: no_grad avoids building autograd graphs and keeps
    # evaluation memory bounded (the original accumulated graphs per batch).
    with torch.no_grad():
        for step, batch_samples in enumerate(test_loader):
            data = batch_samples["image"]
            label = batch_samples["label"]
            target = tools[0].encode(label)
            data = data.to(devices)
            label_flatten, length = tools[1](target)
            target, label_flatten = target.to(devices), label_flatten.to(devices)
            features = models[0](data)
            attention_map, classifier_features = models[1](features)
            # TODO: add code that saves the attention maps
            # vis_att_map(label, attention_map, cfgs.saving_cfgs["saving_path"], step)
            output, out_length = models[2](features[-1], attention_map, target, length, True, classifier_features)
            pred = tools[0].decode(output, out_length)[0]
            tools[2].add_iter(output, out_length, length, label)
            for i in range(len(label)):
                temp.append("label: %s  ----- pred: %s\n" % (label[i], pred[i]))
    print("********************************************************")
    acc, ar, cer, wer = tools[2].show()
    print("********************************************************")
    train_or_eval(models, "train")
    return acc, ar, cer, wer, temp


def vis_att_map(labels, attention_map, path, step):
    """Save per-timestep attention maps as images for visual inspection.

    Used to check whether recognition errors are caused by misalignment.

    Args:
        labels: list of ground-truth strings for the batch.
        attention_map: attention tensor of shape [N, maxT, H, W]
            (e.g. [N, 25, 8, 32]).
        path: root directory; one sub-directory is created per sample
            (existing ones are wiped first).
        step: batch index, used to make each sample's folder name unique.
    """
    for i in range(len(labels)):
        print(labels[i])
        # NOTE(review): folder naming assumes batch size 10 — confirm.
        sub_path = os.path.join(path, str(step * 10 + i) + '_' + labels[i])
        if os.path.exists(sub_path):
            shutil.rmtree(sub_path)
        os.makedirs(sub_path)
        for j in range(cfgs.net_cfgs["CAM_args"]["maxT"]):
            # .cpu() first: .numpy() raises on CUDA tensors.
            plt.imshow(attention_map[i, j, ...].detach().cpu().numpy())
            plt.savefig(os.path.join(sub_path, "%d.png" % j))
            # Clear the figure so maps are not drawn on top of each other
            # and figures do not accumulate in memory.
            plt.clf()


if __name__ == '__main__':
    devices = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    models = create_network()
    dis_cfgs(models)
    optimizers, schedulers = generator_optimizer(models)
    criterion = nn.CrossEntropyLoss()
    criterion = criterion.to(devices)
    train_loader, test_loader = load_dataset()
    print("Preparing Done!")

    train_acc_counter = AccCounter("train acc: ", cfgs.dataset_cfgs["dict_dir"], cfgs.dataset_cfgs["case_sensitive"])
    test_acc_counter = AccCounter("test acc: ", cfgs.dataset_cfgs["dict_dir"], cfgs.dataset_cfgs["case_sensitive"])
    loss_counter = LossCounter()
    character_tool = CharacterTransTool(cfgs.dataset_cfgs["dict_dir"], cfgs.dataset_cfgs["case_sensitive"])

    # BUGFIX: was `is "Test"` — identity comparison with a string literal is
    # implementation-dependent and a SyntaxWarning on CPython 3.8+.
    if cfgs.global_cfgs["state"] == "Test":
        _, _, _, _, temp = model_eval(test_loader, models, [character_tool, flatten_label, test_acc_counter])
        with open(os.path.join(cfgs.saving_cfgs["saving_path"], 'res.txt'), 'a', encoding='utf-8') as fl:
            for item in temp:
                fl.write(item)
        raise SystemExit

    total_iters = len(train_loader)
    pre_acc = 0
    for epoch in range(cfgs.global_cfgs["epochs"]):
        for step, batch_samples in enumerate(train_loader):
            train_or_eval(models, 'train')
            data = batch_samples["image"]
            data = data.to(devices)
            label = batch_samples["label"]
            target = character_tool.encode(label)
            label_flatten, length = flatten_label(target)
            target, label_flatten = target.to(devices), label_flatten.to(devices)

            features = models[0](data)
            attention_map, classifier_features = models[1](features)
            output, _ = models[2](features[-1], attention_map, target, length, classifier_features=classifier_features)
            train_acc_counter.add_iter(output, length.long(), length, label)
            loss = criterion(output, label_flatten)
            loss_counter.add_iter(loss)

            zero_grad(models)
            loss.backward()

            # Gradient clipping is currently disabled; re-enable if training
            # diverges:
            # nn.utils.clip_grad_norm_(models[0].parameters(), 20, 2)
            # nn.utils.clip_grad_norm_(models[1].parameters(), 20, 2)
            # nn.utils.clip_grad_norm_(models[2].parameters(), 20, 2)

            update_param(optimizers, frozen=[])

            print('epoch: {}, iter: {} / {}, loss: {}'.format(epoch, step, total_iters, loss_counter.get_loss()))
            train_acc_counter.show()
            print('-----------------------------------------------------------------')

        acc, ar, cer, wer, temp = model_eval(test_loader, models, [character_tool, flatten_label, test_acc_counter])
        if acc >= pre_acc:
            for i in range(len(models)):
                torch.save(models[i].state_dict(),
                           os.path.join(cfgs.saving_cfgs["saving_path"], 'best_model_%d.pth' % i))
            pre_acc = acc
            with open(os.path.join(cfgs.saving_cfgs["saving_path"], 'acc.txt'), 'a', encoding='utf-8') as fl:
                fl.write("epoch: %d   acc: %f   ar: %f  cer: %f  wer: %f \n" % (epoch, acc, ar, cer, wer))
            with open(os.path.join(cfgs.saving_cfgs["saving_path"], 'res.txt'), 'w', encoding='utf-8') as fl:
                for item in temp:
                    fl.write(item)
        # BUGFIX: the original called update_param(optimizers, frozen=[]) here,
        # re-applying optimizer.step() with the stale gradients of the last
        # training batch; the LR schedulers built above were never advanced.
        # Step the schedulers once per epoch instead.
        for scheduler in schedulers:
            scheduler.step()